{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Funnel Search with Matryoshka Embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style='margin: auto; width: 50%;'><img src='funnel-search.png' width='100%'></div>\n",
    "When building efficient vector search systems, one key challenge is managing storage costs while maintaining acceptable latency and recall. Modern embedding models output vectors with hundreds or thousands of dimensions, creating significant storage and computational overhead for the raw vector and index.\n",
    "\n",
    "Traditionally, the storage requirements are reduced by applying a quantization or dimensionality reduction method just before building the index. For instance, we can save storage by lowering the precision using Product Quantization (PQ) or the number of dimensions using Principal Component Analysis (PCA). These methods analyze the entire vector set to find a more compact one that maintains the semantic relationships between vectors.\n",
    "\n",
    "While effective, these standard approaches reduce precision or dimensionality only once and at a single scale. But what if we could maintain multiple layers of detail simultaneously, like a pyramid of increasingly precise representations?\n",
    "\n",
    "Enter Matryoshka embeddings. Named after Russian nesting dolls (see illustration), these clever constructs embed multiple scales of representation within a single vector. Unlike traditional post-processing methods, Matryoshka embeddings learn this multi-scale structure during the initial training process. The result is remarkable: not only does the full embedding capture input semantics, but each nested subset prefix (first half, first quarter, etc.) provides a coherent, if less detailed, representation.\n",
    "\n",
    "In this notebook, we examine how to use Matryoshka embeddings with Milvus for semantic search. We illustrate an algorithm called \"funnel search\" that allows us to perform similarity search over a small subset of our embedding dimensions without a drastic drop in recall."
   ]
  },
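  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of nested prefixes concrete, here is a minimal sketch of how two embeddings can be compared on a prefix: truncate both vectors to the first `k` dimensions, re-normalize, and take the dot product. The vectors below are random stand-ins (an assumption for illustration, not model output), and `prefix_cosine` is a hypothetical helper name, so the printed numbers only demonstrate the mechanics, not the quality of real Matryoshka embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def prefix_cosine(a, b, k):\n",
    "    \"\"\"Cosine similarity between the first k dimensions of a and b (hypothetical helper).\"\"\"\n",
    "    a_k, b_k = a[:k], b[:k]\n",
    "    return float(a_k @ b_k / (np.linalg.norm(a_k) * np.linalg.norm(b_k)))\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "doc = rng.normal(size=768)  # random stand-in for a document embedding\n",
    "query = doc + 0.5 * rng.normal(size=768)  # noisy copy standing in for a related query\n",
    "\n",
    "# Similarity computed on increasingly long prefixes\n",
    "for k in (128, 256, 512, 768):\n",
    "    print(k, round(prefix_cosine(query, doc, k), 3))"
   ]
  },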
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "from datasets import load_dataset\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pymilvus\n",
    "from pymilvus import MilvusClient\n",
    "from pymilvus import FieldSchema, CollectionSchema, DataType\n",
    "from sentence_transformers import SentenceTransformer\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Matryoshka Embedding Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of using a standard embedding model such as [`sentence-transformers/all-MiniLM-L12-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2), we use [a model from Nomic](https://huggingface.co/nomic-ai/nomic-embed-text-v1) trained especially to produce Matryoshka embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<All keys matched successfully>\n"
     ]
    }
   ],
   "source": [
    "model = SentenceTransformer(\n",
    "    # Remove 'device='mps' if running on non-Mac device\n",
    "    \"nomic-ai/nomic-embed-text-v1.5\",\n",
    "    trust_remote_code=True,\n",
    "    device=\"mps\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Dataset, Embedding Items, and Building Vector Database"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code is a modification of that from the documentation page [\"Movie Search with Sentence Transformers and Milvus\"](https://milvus.io/docs/integrate_with_sentencetransformers.md). First, we load the dataset from HuggingFace. It contains around 35k entries, each corresponding to a movie having a Wikipedia article. We will use the `Title` and `PlotSummary` fields in this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['Release Year', 'Title', 'Origin/Ethnicity', 'Director', 'Cast', 'Genre', 'Wiki Page', 'Plot', 'PlotSummary'],\n",
      "    num_rows: 34886\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "ds = load_dataset(\"vishnupriyavr/wiki-movie-plots-with-summaries\", split=\"train\")\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we connect to a Milvus Lite database, specify the data schema, and create a collection with this schema. We will store both the unnormalized embedding and the first sixth of the embedding in separate fields. The reason for this is that we need the first 1/6th of the Matryoshka embedding for performing a similarity search, and the remaining 5/6ths of the embeddings for reranking and improving the search results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_dim = 768\n",
    "search_dim = 128\n",
    "collection_name = \"movie_embeddings\"\n",
    "\n",
    "client = MilvusClient(uri=\"./wiki-movie-plots-matryoshka.db\")\n",
    "\n",
    "fields = [\n",
    "    FieldSchema(name=\"id\", dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
    "    FieldSchema(name=\"title\", dtype=DataType.VARCHAR, max_length=256),\n",
    "    # First sixth of unnormalized embedding vector\n",
    "    FieldSchema(name=\"head_embedding\", dtype=DataType.FLOAT_VECTOR, dim=search_dim),\n",
    "    # Entire unnormalized embedding vector\n",
    "    FieldSchema(name=\"embedding\", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),\n",
    "]\n",
    "\n",
    "schema = CollectionSchema(fields=fields, enable_dynamic_field=False)\n",
    "client.create_collection(collection_name=collection_name, schema=schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "Milvus does not currently support searching over subsets of embeddings, so we break the embeddings into two parts: the head represents the initial subset of the vector to index and search, and the tail is the remainder. The model is trained for cosine distance similarity search, so we normalize the head embeddings. However, in order to calculate similarities for larger subsets later on, we need to store the norm of the head embedding, so we can unnormalize it before joining to the tail."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To perform search via the first 1/6th of the embedding, we will need to create a vector search index over the `head_embedding` field. Later on, we will compare the results of \"funnel search\" with a regular vector search, and so build a search index over the full embedding also.\n",
    "\n",
    " *Importantly, we use the `COSINE` rather than the `IP` distance metric, because otherwise we would need to keep track of the embedding norms, which would complicate the implementation (this will make more sense once the funnel search algorithm has been described).*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "index_params = client.prepare_index_params()\n",
    "index_params.add_index(\n",
    "    field_name=\"head_embedding\", index_type=\"FLAT\", metric_type=\"COSINE\"\n",
    ")\n",
    "index_params.add_index(field_name=\"embedding\", index_type=\"FLAT\", metric_type=\"COSINE\")\n",
    "client.create_index(collection_name, index_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we encode the plot summaries for all 35k movies and enter the corresponding embeddings in to the database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 69/69 [05:57<00:00,  5.18s/it]\n"
     ]
    }
   ],
   "source": [
    "for batch in tqdm(ds.batch(batch_size=512)):\n",
    "    # This particular model requires us to prefix 'search_document:' to stored entities\n",
    "    plot_summary = [\"search_document: \" + x.strip() for x in batch[\"PlotSummary\"]]\n",
    "\n",
    "    # Output of embedding model is unnormalized\n",
    "    embeddings = model.encode(plot_summary, convert_to_tensor=True)\n",
    "    head_embeddings = embeddings[:, :search_dim]\n",
    "\n",
    "    data = [\n",
    "        {\n",
    "            \"title\": title,\n",
    "            \"head_embedding\": head.cpu().numpy(),\n",
    "            \"embedding\": embedding.cpu().numpy(),\n",
    "        }\n",
    "        for title, head, embedding in zip(batch[\"Title\"], head_embeddings, embeddings)\n",
    "    ]\n",
    "    res = client.insert(collection_name=collection_name, data=data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performing Funnel Search\n",
    "Let's now implement a \"funnel search\" using the first 1/6th of the Matryoshka embedding dimensions. I have three movies in mind for retrieval and have produced my own plot summary for querying the database. We embed the queries, then perform a vector search on the `head_embedding` field, retrieving 128 result candidates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "queries = [\n",
    "    \"An archaeologist searches for ancient artifacts while fighting Nazis.\",\n",
    "    \"A teenager fakes illness to get off school and have adventures with two friends.\",\n",
    "    \"A young couple with a kid look after a hotel during winter and the husband goes insane.\",\n",
    "]\n",
    "\n",
    "\n",
    "# Search the database based on input text\n",
    "def embed_search(data):\n",
    "    embeds = model.encode(data)\n",
    "    return [x for x in embeds]\n",
    "\n",
    "\n",
    "# This particular model requires us to prefix 'search_query:' to queries\n",
    "instruct_queries = [\"search_query: \" + q.strip() for q in queries]\n",
    "search_data = embed_search(instruct_queries)\n",
    "\n",
    "# Normalize head embeddings\n",
    "head_search = [x[:search_dim] for x in search_data]\n",
    "\n",
    "# Perform standard vector search on first sixth of embedding dimensions\n",
    "res = client.search(\n",
    "    collection_name=collection_name,\n",
    "    data=head_search,\n",
    "    anns_field=\"head_embedding\",\n",
    "    limit=128,\n",
    "    output_fields=[\"title\", \"head_embedding\", \"embedding\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, we have performed search over a much smaller vector space, and are therefore likely to have lowered latency and lessened storage requirements for the index relative to search over the full space. Let's examine the top 5 matches for each query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: An archaeologist searches for ancient artifacts while fighting Nazis.\n",
      "Results:\n",
      "\"Pimpernel\" Smith\n",
      "Black Hunters\n",
      "The Passage\n",
      "Counterblast\n",
      "Dominion: Prequel to the Exorcist\n",
      "\n",
      "Query: A teenager fakes illness to get off school and have adventures with two friends.\n",
      "Results:\n",
      "How to Deal\n",
      "Shorts\n",
      "Blackbird\n",
      "Valentine\n",
      "Unfriended\n",
      "\n",
      "Query: A young couple with a kid look after a hotel during winter and the husband goes insane.\n",
      "Results:\n",
      "Ghostkeeper\n",
      "Our Vines Have Tender Grapes\n",
      "The Ref\n",
      "Impact\n",
      "The House in Marsh Road\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for query, hits in zip(queries, res):\n",
    "    rows = [x[\"entity\"] for x in hits][:5]\n",
    "\n",
    "    print(\"Query:\", query)\n",
    "    print(\"Results:\")\n",
    "    for row in rows:\n",
    "        print(row[\"title\"].strip())\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, recall has suffered as a consequence of truncating the embeddings during search. Funnel search fixes this with a clever trick: we can use the remainder of the embedding dimensions to rerank and prune our candidate list to recover retrieval performance without running any additional expensive vector searches.\n",
    "\n",
    "For ease of exposition of the funnel search algorithm, we convert the Milvus search hits for each query into a Pandas dataframe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hits_to_dataframe(hits: pymilvus.client.abstract.Hits) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Convert a Milvus search result to a Pandas dataframe. This function is specific to our data schema.\n",
    "\n",
    "    \"\"\"\n",
    "    rows = [x[\"entity\"] for x in hits]\n",
    "    rows_dict = [\n",
    "        {\"title\": x[\"title\"], \"embedding\": torch.tensor(x[\"embedding\"])} for x in rows\n",
    "    ]\n",
    "    return pd.DataFrame.from_records(rows_dict)\n",
    "\n",
    "\n",
    "dfs = [hits_to_dataframe(hits) for hits in res]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, to perform funnel search we iterate over the increasingly larger subsets of the embeddings. At each iteration, we rerank the candidates according to the new similarities and prune some fraction of the lowest ranked ones.\n",
    "\n",
    "To make this concrete, from the previous step we have retrieved 128 candidates using 1/6 of the embedding and query dimensions. The first step in performing funnel search is to recalculate the similarities between the queries and candidates using *the first 1/3 of the dimensions*. The bottom 64 candidates are pruned. Then we repeat this process with *the first 2/3 of the dimensions*, and then *all of the dimensions*, successively pruning to 32 and 16 candidates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An optimized implementation would vectorize the calculation of similarity scores across rows (using a matrix)\n",
    "def calculate_score(row, query_emb=None, dims=768):\n",
    "    emb = F.normalize(row[\"embedding\"][:dims], dim=-1)\n",
    "    return (emb @ query_emb).item()\n",
    "\n",
    "\n",
    "# You could also add a top-K parameter as a termination condition\n",
    "def funnel_search(\n",
    "    df: pd.DataFrame, query_emb, scales=[256, 512, 768], prune_ratio=0.5\n",
    ") -> pd.DataFrame:\n",
    "    # Loop over increasing prefixes of the embeddings\n",
    "    for dims in scales:\n",
    "        # Query vector must be normalized for each new dimensionality\n",
    "        emb = torch.tensor(query_emb[:dims] / np.linalg.norm(query_emb[:dims]))\n",
    "\n",
    "        # Score\n",
    "        scores = df.apply(\n",
    "            functools.partial(calculate_score, query_emb=emb, dims=dims), axis=1\n",
    "        )\n",
    "        df[\"scores\"] = scores\n",
    "\n",
    "        # Re-rank\n",
    "        df = df.sort_values(by=\"scores\", ascending=False)\n",
    "\n",
    "        # Prune (in our case, remove half of candidates at each step)\n",
    "        df = df.head(int(prune_ratio * len(df)))\n",
    "\n",
    "    return df\n",
    "\n",
    "\n",
    "dfs_results = [\n",
    "    {\"query\": query, \"results\": funnel_search(df, query_emb)}\n",
    "    for query, df, query_emb in zip(queries, dfs, search_data)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "An archaeologist searches for ancient artifacts while fighting Nazis. \n",
      " 0           \"Pimpernel\" Smith\n",
      "1               Black Hunters\n",
      "29    Raiders of the Lost Ark\n",
      "34             The Master Key\n",
      "51            My Gun Is Quick\n",
      "Name: title, dtype: object \n",
      "\n",
      "A teenager fakes illness to get off school and have adventures with two friends. \n",
      " 21               How I Live Now\n",
      "32     On the Edge of Innocence\n",
      "77             Bratz: The Movie\n",
      "4                    Unfriended\n",
      "108                  Simon Says\n",
      "Name: title, dtype: object \n",
      "\n",
      "A young couple with a kid look after a hotel during winter and the husband goes insane. \n",
      " 9         The Shining\n",
      "0         Ghostkeeper\n",
      "11     Fast and Loose\n",
      "7      Killing Ground\n",
      "12         Home Alone\n",
      "Name: title, dtype: object \n",
      "\n"
     ]
    }
   ],
   "source": [
    "for d in dfs_results:\n",
    "    print(d[\"query\"], \"\\n\", d[\"results\"][:5][\"title\"], \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have been able to restore recall without performing any additional vector searches! Qualitatively, these results seem to have higher recall for \"Raiders of the Lost Ark\" and \"The Shining\" than the standard vector search in the tutorial, [\"Movie Search using Milvus and Sentence Transformers\"](https://milvus.io/docs/integrate_with_sentencetransformers.md), which uses a different embedding model. However, it is unable to find \"Ferris Bueller's Day Off\", which we will return to later in the notebook. (See the paper [Matryoshka Representation Learning](https://arxiv.org/abs/2205.13147) for more quantitative experiments and benchmarking.)"
   ]
  },
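  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `funnel_search` function above scores one dataframe row at a time. As its opening comment suggests, the scoring can be vectorized. The cell below is a minimal sketch of that idea in NumPy, not a drop-in replacement: `funnel_search_vectorized` is a hypothetical name, it takes a plain `(n, d)` array of candidate embeddings instead of a dataframe, and it is exercised on random stand-in data in which candidate 7 is a noisy copy of the query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def funnel_search_vectorized(\n",
    "    cand_embs, query_emb, scales=(256, 512, 768), prune_ratio=0.5\n",
    "):\n",
    "    \"\"\"Return indices of the surviving candidates, best first.\n",
    "\n",
    "    cand_embs: (n, d) array of unnormalized candidate embeddings.\n",
    "    query_emb: (d,) unnormalized query embedding.\n",
    "    \"\"\"\n",
    "    keep = np.arange(len(cand_embs))\n",
    "    for dims in scales:\n",
    "        # Re-normalize the prefixes at the current scale\n",
    "        cands = cand_embs[keep, :dims]\n",
    "        cands = cands / np.linalg.norm(cands, axis=1, keepdims=True)\n",
    "        q = query_emb[:dims] / np.linalg.norm(query_emb[:dims])\n",
    "\n",
    "        # One matrix-vector product scores every remaining candidate\n",
    "        scores = cands @ q\n",
    "\n",
    "        # Re-rank, then keep the top fraction\n",
    "        order = np.argsort(-scores)\n",
    "        keep = keep[order[: int(prune_ratio * len(keep))]]\n",
    "    return keep\n",
    "\n",
    "\n",
    "# Random stand-in data (not model output): candidate 7 is a noisy copy of the query\n",
    "rng = np.random.default_rng(0)\n",
    "cand_embs = rng.normal(size=(128, 768))\n",
    "query_emb = cand_embs[7] + 0.1 * rng.normal(size=768)\n",
    "print(funnel_search_vectorized(cand_embs, query_emb)[:5])"
   ]
  },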
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparing Funnel Search to Regular Search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's compare the results of our funnel search to a standard vector search *on the same dataset with the same embedding model*. We perform a search on the full embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search on entire embeddings\n",
    "res = client.search(\n",
    "    collection_name=collection_name,\n",
    "    data=search_data,\n",
    "    anns_field=\"embedding\",\n",
    "    limit=5,\n",
    "    output_fields=[\"title\", \"embedding\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: An archaeologist searches for ancient artifacts while fighting Nazis.\n",
      "Results:\n",
      "\"Pimpernel\" Smith\n",
      "Black Hunters\n",
      "Raiders of the Lost Ark\n",
      "The Master Key\n",
      "My Gun Is Quick\n",
      "\n",
      "Query: A teenager fakes illness to get off school and have adventures with two friends.\n",
      "Results:\n",
      "A Walk to Remember\n",
      "Ferris Bueller's Day Off\n",
      "How I Live Now\n",
      "On the Edge of Innocence\n",
      "Bratz: The Movie\n",
      "\n",
      "Query: A young couple with a kid look after a hotel during winter and the husband goes insane.\n",
      "Results:\n",
      "The Shining\n",
      "Ghostkeeper\n",
      "Fast and Loose\n",
      "Killing Ground\n",
      "Home Alone\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for query, hits in zip(queries, res):\n",
    "    rows = [x[\"entity\"] for x in hits]\n",
    "\n",
    "    print(\"Query:\", query)\n",
    "    print(\"Results:\")\n",
    "    for row in rows:\n",
    "        print(row[\"title\"].strip())\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the exception of the results for \"A teenager fakes illness to get off school...\", the results under funnel search are almost identical to the full search, even though the funnel search was performed on a search space of 128 dimensions vs 768 dimensions for the regular one.\t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Investigating Funnel Search Recall Failure for Ferris Bueller's Day Off"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why didn't funnel search succeed in retrieving Ferris Bueller's Day Off? Let's examine whether or not it was in the original candidate list or was mistakenly filtered out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "queries2 = [\n",
    "    \"A teenager fakes illness to get off school and have adventures with two friends.\"\n",
    "]\n",
    "\n",
    "\n",
    "# Search the database based on input text\n",
    "def embed_search(data):\n",
    "    embeds = model.encode(data)\n",
    "    return [x for x in embeds]\n",
    "\n",
    "\n",
    "instruct_queries = [\"search_query: \" + q.strip() for q in queries2]\n",
    "search_data2 = embed_search(instruct_queries)\n",
    "head_search2 = [x[:search_dim] for x in search_data2]\n",
    "\n",
    "# Perform standard vector search on subset of embeddings\n",
    "res = client.search(\n",
    "    collection_name=collection_name,\n",
    "    data=head_search2,\n",
    "    anns_field=\"head_embedding\",\n",
    "    limit=256,\n",
    "    output_fields=[\"title\", \"head_embedding\", \"embedding\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: A teenager fakes illness to get off school and have adventures with two friends.\n",
      "Row 228: Ferris Bueller's Day Off\n"
     ]
    }
   ],
   "source": [
    "for query, hits in zip(queries, res):\n",
    "    rows = [x[\"entity\"] for x in hits]\n",
    "\n",
    "    print(\"Query:\", queries2[0])\n",
    "    for idx, row in enumerate(rows):\n",
    "        if row[\"title\"].strip() == \"Ferris Bueller's Day Off\":\n",
    "            print(f\"Row {idx}: Ferris Bueller's Day Off\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the issue was that the initial candidate list was not large enough, or rather, the desired hit is not similar enough to the query at the highest level of granularity. Changing it from `128` to `256` results in successful retrieval. *We should form a rule-of-thumb to set the number of candidates on a held-out set to empirically evaluate the trade-off between recall and latency.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A teenager fakes illness to get off school and have adventures with two friends. \n",
      "       A Walk to Remember\n",
      "Ferris Bueller's Day Off\n",
      "          How I Live Now\n",
      "On the Edge of Innocence\n",
      "        Bratz: The Movie\n",
      "              Unfriended\n",
      "              Simon Says \n",
      "\n"
     ]
    }
   ],
   "source": [
    "dfs = [hits_to_dataframe(hits) for hits in res]\n",
    "\n",
    "dfs_results = [\n",
    "    {\"query\": query, \"results\": funnel_search(df, query_emb)}\n",
    "    for query, df, query_emb in zip(queries2, dfs, search_data2)\n",
    "]\n",
    "\n",
    "for d in dfs_results:\n",
    "    print(d[\"query\"], \"\\n\", d[\"results\"][:7][\"title\"].to_string(index=False), \"\\n\")"
   ]
  },
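  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to choose the candidate count empirically is to measure, on held-out queries, how often the best full-embedding match survives the initial head search as the candidate list grows. The cell below is a rough sketch of that measurement under stated assumptions: `candidate_recall` and `normalize_rows` are hypothetical helpers, similarities are computed by brute force in NumPy rather than through Milvus, and the embeddings are random stand-ins rather than model output, so only the procedure carries over, not the numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def normalize_rows(x):\n",
    "    return x / np.linalg.norm(x, axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "def candidate_recall(doc_embs, query_embs, head_dim, n_candidates):\n",
    "    \"\"\"Fraction of queries whose full-embedding top-1 hit appears among the\n",
    "    top n_candidates retrieved using only the first head_dim dimensions.\"\"\"\n",
    "    full_scores = normalize_rows(query_embs) @ normalize_rows(doc_embs).T\n",
    "    head_queries = normalize_rows(query_embs[:, :head_dim])\n",
    "    head_docs = normalize_rows(doc_embs[:, :head_dim])\n",
    "    head_scores = head_queries @ head_docs.T\n",
    "\n",
    "    target = full_scores.argmax(axis=1)\n",
    "    top_head = np.argsort(-head_scores, axis=1)[:, :n_candidates]\n",
    "    return float(np.mean((top_head == target[:, None]).any(axis=1)))\n",
    "\n",
    "\n",
    "# Random stand-ins for held-out documents and queries (not model output)\n",
    "rng = np.random.default_rng(0)\n",
    "doc_embs = rng.normal(size=(2000, 768))\n",
    "query_embs = doc_embs[:100] + 3.0 * rng.normal(size=(100, 768))\n",
    "\n",
    "# Sweep the candidate count and report recall of the head search\n",
    "for n in (16, 64, 256, 1024):\n",
    "    print(n, candidate_recall(doc_embs, query_embs, head_dim=128, n_candidates=n))"
   ]
  },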
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Does the order matter? Prefix vs suffix embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model was trained to perform well matching recursively smaller prefixes of the embeddings. Does the order of the dimensions we use matter? For instance, could we also take subsets of the embeddings that are suffixes? In this experiment, we reverse the order of the dimensions in the Matryoshka embeddings and perform a funnel search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "client = MilvusClient(uri=\"./wikiplots-matryoshka-flipped.db\")\n",
    "\n",
    "fields = [\n",
    "    FieldSchema(name=\"id\", dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
    "    FieldSchema(name=\"title\", dtype=DataType.VARCHAR, max_length=256),\n",
    "    FieldSchema(name=\"head_embedding\", dtype=DataType.FLOAT_VECTOR, dim=search_dim),\n",
    "    FieldSchema(name=\"embedding\", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),\n",
    "]\n",
    "\n",
    "schema = CollectionSchema(fields=fields, enable_dynamic_field=False)\n",
    "client.create_collection(collection_name=collection_name, schema=schema)\n",
    "\n",
    "index_params = client.prepare_index_params()\n",
    "index_params.add_index(\n",
    "    field_name=\"head_embedding\", index_type=\"FLAT\", metric_type=\"COSINE\"\n",
    ")\n",
    "client.create_index(collection_name, index_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 69/69 [05:50<00:00,  5.08s/it]\n"
     ]
    }
   ],
   "source": [
    "for batch in tqdm(ds.batch(batch_size=512)):\n",
    "    plot_summary = [\"search_document: \" + x.strip() for x in batch[\"PlotSummary\"]]\n",
    "\n",
    "    # Encode and flip embeddings\n",
    "    embeddings = model.encode(plot_summary, convert_to_tensor=True)\n",
    "    embeddings = torch.flip(embeddings, dims=[-1])\n",
    "    head_embeddings = embeddings[:, :search_dim]\n",
    "\n",
    "    data = [\n",
    "        {\n",
    "            \"title\": title,\n",
    "            \"head_embedding\": head.cpu().numpy(),\n",
    "            \"embedding\": embedding.cpu().numpy(),\n",
    "        }\n",
    "        for title, head, embedding in zip(batch[\"Title\"], head_embeddings, embeddings)\n",
    "    ]\n",
    "    res = client.insert(collection_name=collection_name, data=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize head embeddings\n",
    "\n",
    "flip_search_data = [\n",
    "    torch.flip(torch.tensor(x), dims=[-1]).cpu().numpy() for x in search_data\n",
    "]\n",
    "flip_head_search = [x[:search_dim] for x in flip_search_data]\n",
    "\n",
    "# Perform standard vector search on subset of embeddings\n",
    "res = client.search(\n",
    "    collection_name=collection_name,\n",
    "    data=flip_head_search,\n",
    "    anns_field=\"head_embedding\",\n",
    "    limit=128,\n",
    "    output_fields=[\"title\", \"head_embedding\", \"embedding\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "An archaeologist searches for ancient artifacts while fighting Nazis. \n",
      "       \"Pimpernel\" Smith\n",
      "          Black Hunters\n",
      "Raiders of the Lost Ark\n",
      "         The Master Key\n",
      "        My Gun Is Quick\n",
      "            The Passage\n",
      "        The Mole People \n",
      "\n",
      "A teenager fakes illness to get off school and have adventures with two friends. \n",
      "                       A Walk to Remember\n",
      "                          How I Live Now\n",
      "                              Unfriended\n",
      "Cirque du Freak: The Vampire's Assistant\n",
      "                             Last Summer\n",
      "                                 Contest\n",
      "                                 Day One \n",
      "\n",
      "A young couple with a kid look after a hotel during winter and the husband goes insane. \n",
      "         Ghostkeeper\n",
      "     Killing Ground\n",
      "Leopard in the Snow\n",
      "              Stone\n",
      "          Afterglow\n",
      "         Unfaithful\n",
      "     Always a Bride \n",
      "\n"
     ]
    }
   ],
   "source": [
    "dfs = [hits_to_dataframe(hits) for hits in res]\n",
    "\n",
    "dfs_results = [\n",
    "    {\"query\": query, \"results\": funnel_search(df, query_emb)}\n",
    "    for query, df, query_emb in zip(queries, dfs, flip_search_data)\n",
    "]\n",
    "\n",
    "for d in dfs_results:\n",
    "    print(\n",
    "        d[\"query\"],\n",
    "        \"\\n\",\n",
    "        d[\"results\"][:7][\"title\"].to_string(index=False, header=False),\n",
    "        \"\\n\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recall is much poorer than funnel search or regular search as expected (the embedding model was trained by contrastive learning on prefixes of the embedding dimensions, not suffixes)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a comparison of our search results across methods:\n",
    "<div style='margin: auto; width: 80%;'><img src='results-raiders-of-the-lost-ark.png' width='100%'></div>\n",
    "<div style='margin: auto; width: 100%;'><img src='results-ferris-buellers-day-off.png' width='100%'></div>\n",
    "<div style='margin: auto; width: 80%;'><img src='results-the-shining.png' width='100%'></div>\n",
    "We have shown how to use Matryoshka embeddings with Milvus for performing a more efficient semantic search algorithm called \"funnel search.\" We also explored the importance of the reranking and pruning steps of the algorithm, as well as a failure mode when the initial candidate list is too small. Finally, we discussed how the order of the dimensions is important when forming sub-embeddings - it must be in the same way for which the model was trained. Or rather, it is only because the model was trained in a certain way that prefixes of the embeddings are meaningful. Now you know how to implement Matryoshka embeddings and funnel search to reduce the storage costs of semantic search without sacrificing too much retrieval performance!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "milvus",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
